# BSD 3-Clause License
#
# Copyright (c) 2017 xxxx
# All rights reserved.
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# * Redistributions of source code must retain the above copyright notice, this
#   list of conditions and the following disclaimer.
#
# * Redistributions in binary form must reproduce the above copyright notice,
#   this list of conditions and the following disclaimer in the documentation
#   and/or other materials provided with the distribution.
#
# * Neither the name of the copyright holder nor the names of its
#   contributors may be used to endorse or promote products derived from
#   this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
# ============================================================================
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import math
import functools

import torch
import torch.nn as nn
import torch.nn.init as init
import torch.optim as optim
import apex

import common.metrics
import models


def update_argparser(parser):
    """Register WDSR hyper-parameter flags and per-dataset training defaults.

    First delegates to the shared ``models.update_argparser`` so base flags
    (e.g. ``--dataset``) exist, then adds the architecture flags, and finally
    installs tuned defaults for the datasets we know how to train on.

    Raises:
        NotImplementedError: if the selected dataset has no tuned defaults.
    """
    models.update_argparser(parser)
    # Peek at the already-known flags to find out which dataset was chosen.
    known_args, _ = parser.parse_known_args()

    # (flag, help text, default, type) — registered identically to four
    # separate add_argument calls.
    flag_specs = (
        ('--num_blocks',
         'Number of residual blocks in networks.', 16, int),
        ('--num_residual_units',
         'Number of residual units in networks.', 32, int),
        ('--width_multiplier',
         'Width multiplier inside residual blocks.', 4, float),
        ('--temporal_size',
         'Number of frames for burst input.', None, int),
    )
    for flag, help_text, default, flag_type in flag_specs:
        parser.add_argument(flag, help=help_text, default=default, type=flag_type)

    if not known_args.dataset.startswith('div2k'):
        raise NotImplementedError('Needs to tune hyper parameters for new dataset.')
    parser.set_defaults(
        train_epochs=30,
        learning_rate_milestones=(20, 25),
        learning_rate_decay=0.2,
        save_checkpoints_epochs=1,
        lr_patch_size=48,
        train_temporal_size=1,
        eval_temporal_size=1,
    )


def get_model_spec(params):
    """Build the model and its training companions from parsed flags.

    Args:
        params: parsed argparse namespace (needs ``learning_rate``, ``scale``,
            and the architecture flags consumed by ``MODEL``).

    Returns:
        Tuple ``(model, loss_fn, optimizer, lr_scheduler, metrics)`` where
        ``metrics`` maps metric names to callables.
    """
    model = MODEL(params)
    # Generator form avoids materializing an intermediate list just to sum.
    print('# of parameters: ', sum(p.numel() for p in model.parameters()))
    # NpuFusedAdam is the Ascend-NPU fused drop-in for optim.Adam.
    optimizer = apex.optimizers.NpuFusedAdam(model.parameters(),
                                             params.learning_rate)
    # Anneal over the whole training run; falls back to the previously
    # hard-coded 30 (the div2k default epoch count) if train_epochs is absent.
    lr_scheduler = optim.lr_scheduler.CosineAnnealingLR(
        optimizer, T_max=getattr(params, 'train_epochs', 30), eta_min=0)
    loss_fn = torch.nn.L1Loss()
    metrics = {
        'loss':
            loss_fn,
        'PSNR':
            functools.partial(
                common.metrics.psnr,
                shave=0 if params.scale == 1 else params.scale + 6),
        'PSNR_Y':
            functools.partial(
                common.metrics.psnr_y,
                shave=0 if params.scale == 1 else params.scale),
    }
    return model, loss_fn, optimizer, lr_scheduler, metrics


class MODEL(nn.Module):
    """WDSR-style super-resolution network.

    Pipeline: optional burst-frame folding -> mean subtraction -> deep
    weight-normalized residual body added to a shallow skip conv ->
    pixel-shuffle upsampling -> mean re-addition.
    """

    def __init__(self, params):
        super(MODEL, self).__init__()
        self.temporal_size = params.temporal_size
        self.image_mean = params.image_mean
        kernel_size = 3
        skip_kernel_size = 5
        weight_norm = torch.nn.utils.weight_norm
        num_inputs = params.num_channels
        # Burst mode: all frames are stacked along the channel axis.
        if self.temporal_size:
            num_inputs *= self.temporal_size
        # Each output position carries scale*scale sub-pixels per channel,
        # unpacked later by nn.PixelShuffle.
        num_outputs = params.scale * params.scale * params.num_channels

        body = []
        conv = weight_norm(
            nn.Conv2d(
                num_inputs,
                params.num_residual_units,
                kernel_size,
                padding=kernel_size // 2))
        init.ones_(conv.weight_g)
        init.zeros_(conv.bias)
        body.append(conv)
        for _ in range(params.num_blocks):
            body.append(
                Block(
                    params.num_residual_units,
                    kernel_size,
                    params.width_multiplier,
                    weight_norm=weight_norm,
                    # Down-scale each residual so the stacked blocks keep a
                    # stable magnitude at initialization.
                    res_scale=1 / math.sqrt(params.num_blocks),
                ))
        conv = weight_norm(
            nn.Conv2d(
                params.num_residual_units,
                num_outputs,
                kernel_size,
                padding=kernel_size // 2))
        init.ones_(conv.weight_g)
        init.zeros_(conv.bias)
        body.append(conv)
        self.body = nn.Sequential(*body)

        # Shallow skip path; identity (empty Sequential) when channel counts
        # already match.
        skip = []
        if num_inputs != num_outputs:
            conv = weight_norm(
                nn.Conv2d(
                    num_inputs,
                    num_outputs,
                    skip_kernel_size,
                    padding=skip_kernel_size // 2))
            init.ones_(conv.weight_g)
            init.zeros_(conv.bias)
            skip.append(conv)
        self.skip = nn.Sequential(*skip)

        # Sub-pixel upsampler; identity when scale == 1.
        shuf = []
        if params.scale > 1:
            shuf.append(nn.PixelShuffle(params.scale))
        self.shuf = nn.Sequential(*shuf)

    def forward(self, x):
        """Super-resolve ``x``; the input tensor is left unmodified."""
        if self.temporal_size:
            # Fold frames into channels — assumes a 5-D (N, T, C, H, W)
            # burst layout (TODO confirm against the data pipeline).
            x = x.view([x.shape[0], -1, x.shape[3], x.shape[4]])
        # Out-of-place ops: the original in-place `x -= mean` / `x += mean`
        # silently mutated the caller's input batch (the burst `view` above
        # shares storage with the input, so even that path was affected).
        x = x - self.image_mean
        x = self.body(x) + self.skip(x)
        x = self.shuf(x)
        x = x + self.image_mean
        if self.temporal_size:
            # Restore a singleton temporal axis on the output.
            x = x.view([x.shape[0], -1, 1, x.shape[2], x.shape[3]])
        return x


class Block(nn.Module):
    """Wide-activation residual block: expand conv -> ReLU -> project conv,
    with the result added back onto the identity path.

    Both convs are weight-normalized; the expansion conv's gain (weight_g)
    starts at 2.0 and the projection conv's gain at ``res_scale``, with
    zero-initialized biases.
    """

    def __init__(self,
                 num_residual_units,
                 kernel_size,
                 width_multiplier=1,
                 weight_norm=torch.nn.utils.weight_norm,
                 res_scale=1):
        super(Block, self).__init__()
        pad = kernel_size // 2
        expanded = int(num_residual_units * width_multiplier)

        # Widen the features before the non-linearity.
        expand_conv = weight_norm(
            nn.Conv2d(num_residual_units, expanded, kernel_size, padding=pad))
        init.constant_(expand_conv.weight_g, 2.0)
        init.zeros_(expand_conv.bias)

        # Project back to the residual width, pre-scaled by res_scale.
        project_conv = weight_norm(
            nn.Conv2d(expanded, num_residual_units, kernel_size, padding=pad))
        init.constant_(project_conv.weight_g, res_scale)
        init.zeros_(project_conv.bias)

        self.body = nn.Sequential(expand_conv, nn.ReLU(True), project_conv)

    def forward(self, x):
        # Identity skip around the conv stack.
        return x + self.body(x)
